{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM Meta Learner on MiniImageNet Dataset\n",
    "\n",
    "Please download the data from the link below, extract it, and save it in the same folder as this notebook: https://drive.google.com/file/d/1rV3aj_hgfNTfCakffpPm7Vhpr1in87CR\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#import libraries\n",
    "from __future__ import division, print_function, absolute_import\n",
    "import os\n",
    "import re\n",
    "import pdb\n",
    "import copy\n",
    "import glob\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import PIL.Image as PILI\n",
    "from tqdm import tqdm\n",
    "from collections import OrderedDict\n",
    "import random\n",
    "import logging\n",
    "import os\n",
    "from torchvision.datasets.utils import download_file_from_google_drive, extract_archive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 1: Data Loader "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class EpisodeDataset(data.Dataset):\n",
    "    \"\"\"Dataset of classes: item `idx` is one shuffled batch of images drawn from class `idx`.\n",
    "\n",
    "    Every usable class folder under `<root>/<phase>` gets its own DataLoader over a\n",
    "    ClassDataset with batch_size = n_shot + n_eval, so indexing this dataset returns the\n",
    "    (images, labels) batch of a single class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):\n",
    "        \"\"\"Args:\n",
    "            root (str): path to data\n",
    "            phase (str): train, val or test\n",
    "            n_shot (int): how many examples per class for training (k/n_support)\n",
    "            n_eval (int): how many examples per class for evaluation\n",
    "                - n_shot + n_eval = batch_size for data.DataLoader of ClassDataset\n",
    "            transform (torchvision.transforms): data augmentation\n",
    "        \"\"\"\n",
    "        root = os.path.join(root, phase)\n",
    "\n",
    "        # Only non-empty sub-directories count as classes: stray entries such as .DS_Store or an\n",
    "        #  empty .ipynb_checkpoints folder would otherwise become bogus classes and make the\n",
    "        #  shuffling DataLoader fail on an empty dataset.\n",
    "        self.labels = []\n",
    "        images = []\n",
    "        for label in sorted(os.listdir(root)):\n",
    "            class_dir = os.path.join(root, label)\n",
    "            if not os.path.isdir(class_dir):\n",
    "                continue\n",
    "            # glob order depends on the filesystem; sort so a fixed seed picks the same images\n",
    "            paths = sorted(path for path in glob.glob(os.path.join(class_dir, '*'))\n",
    "                           if os.path.isfile(path))\n",
    "            if paths:\n",
    "                self.labels.append(label)\n",
    "                images.append(paths)\n",
    "\n",
    "        self.episode_loader = [data.DataLoader(\n",
    "            ClassDataset(images=paths, label=idx, transform=transform),\n",
    "            batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx, paths in enumerate(images)]\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # a fresh iterator per call, so each access draws a newly shuffled batch of class `idx`\n",
    "        return next(iter(self.episode_loader[idx]))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ClassDataset(data.Dataset):\n",
    "    \"\"\"Images of a single class; every item is returned together with that class's label.\"\"\"\n",
    "\n",
    "    def __init__(self, images, label, transform=None):\n",
    "        \"\"\"Args:\n",
    "            images (list of str): each item is a path to an image of the same label\n",
    "            label (int): the label of all the images\n",
    "            transform (callable or None): applied to the loaded RGB image when given\n",
    "        \"\"\"\n",
    "        self.images, self.label, self.transform = images, label, transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        picture = PILI.open(self.images[idx]).convert('RGB')\n",
    "        if self.transform is None:\n",
    "            return picture, self.label\n",
    "        return self.transform(picture), self.label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class EpisodicSampler(data.Sampler):\n",
    "    \"\"\"Batch sampler: each of its `n_episode` batches is `n_class` distinct random class indices.\"\"\"\n",
    "\n",
    "    def __init__(self, total_classes, n_class, n_episode):\n",
    "        \"\"\"Args:\n",
    "            total_classes (int): number of classes to draw from\n",
    "            n_class (int): classes per episode\n",
    "            n_episode (int): number of episodes produced by one pass\n",
    "        \"\"\"\n",
    "        self.total_classes, self.n_class, self.n_episode = total_classes, n_class, n_episode\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n_episode\n",
    "\n",
    "    def __iter__(self):\n",
    "        for _ in range(self.n_episode):\n",
    "            # a fresh permutation per episode; its head is the episode's class subset\n",
    "            shuffled_classes = torch.randperm(self.total_classes)\n",
    "            yield shuffled_classes[:self.n_class]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_data(args):\n",
    "    \"\"\"Build the episodic train / val / test loaders described by `args`.\n",
    "\n",
    "    Reads args['image_size'], args['data_root'], args['n_shot'], args['n_eval'],\n",
    "    args['n_class'], args['episode'] and args['episode_val'].\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, val_loader, test_loader)\n",
    "    \"\"\"\n",
    "    side = args['image_size']\n",
    "    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "\n",
    "    # resize + center crop: used for the val and test phases\n",
    "    eval_transform = transforms.Compose([\n",
    "        transforms.Resize(side * 8 // 7),\n",
    "        transforms.CenterCrop(side),\n",
    "        transforms.ToTensor(),\n",
    "        normalize])\n",
    "\n",
    "    # random crop / flip / colour jitter: used for the train phase\n",
    "    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)\n",
    "    train_transform = transforms.Compose([\n",
    "        transforms.RandomResizedCrop(side),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        jitter,\n",
    "        transforms.ToTensor(),\n",
    "        normalize])\n",
    "\n",
    "    def make_set(phase, transform):\n",
    "        return EpisodeDataset(args['data_root'], phase, args['n_shot'], args['n_eval'],\n",
    "            transform=transform)\n",
    "\n",
    "    def make_loader(dataset, n_episode):\n",
    "        sampler = EpisodicSampler(len(dataset), args['n_class'], n_episode)\n",
    "        return data.DataLoader(dataset, num_workers=0, pin_memory=True, batch_sampler=sampler)\n",
    "\n",
    "    train_set = make_set('train', train_transform)\n",
    "    val_set = make_set('val', eval_transform)\n",
    "    test_set = make_set('test', eval_transform)\n",
    "\n",
    "    train_loader = make_loader(train_set, args['episode'])\n",
    "    val_loader = make_loader(val_set, args['episode_val'])\n",
    "    test_loader = make_loader(test_set, args['episode_val'])\n",
    "    return train_loader, val_loader, test_loader\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 2: Learner "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Learner(nn.Module):\n",
    "    \"\"\"Four conv blocks (conv -> batchnorm -> relu -> maxpool, 32 channels each) plus a linear classifier.\"\"\"\n",
    "\n",
    "    def __init__(self, image_size, bn_eps, bn_momentum, n_classes):\n",
    "        \"\"\"Args:\n",
    "            image_size (int): side length of the square input images\n",
    "            bn_eps (float): eps given to every BatchNorm2d\n",
    "            bn_momentum (float): momentum given to every BatchNorm2d\n",
    "            n_classes (int): number of output logits\n",
    "        \"\"\"\n",
    "        super(Learner, self).__init__()\n",
    "        n_blocks, width = 4, 32\n",
    "\n",
    "        # layers are created block by block so parameter order stays conv1, norm1, conv2, ...\n",
    "        layers = OrderedDict()\n",
    "        in_channels = 3\n",
    "        for n in range(1, n_blocks + 1):\n",
    "            layers['conv%d' % n] = nn.Conv2d(in_channels, width, 3, padding=1)\n",
    "            layers['norm%d' % n] = nn.BatchNorm2d(width, bn_eps, bn_momentum)\n",
    "            layers['relu%d' % n] = nn.ReLU(inplace=False)\n",
    "            layers['pool%d' % n] = nn.MaxPool2d(2)\n",
    "            in_channels = width\n",
    "        self.model = nn.ModuleDict({'features': nn.Sequential(layers)})\n",
    "\n",
    "        # each MaxPool2d(2) halves the spatial side; the classifier is registered after the features\n",
    "        clr_in = image_size // 2**n_blocks\n",
    "        self.model['cls'] = nn.Linear(width * clr_in * clr_in, n_classes)\n",
    "        self.criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        feature_maps = self.model.features(x)\n",
    "        flattened = torch.reshape(feature_maps, [feature_maps.size(0), -1])\n",
    "        return self.model.cls(flattened)\n",
    "\n",
    "    def get_flat_params(self):\n",
    "        \"\"\"Return all parameters of `self.model` concatenated into one 1-D tensor.\"\"\"\n",
    "        pieces = [p.view(-1) for p in self.model.parameters()]\n",
    "        return torch.cat(pieces, 0)\n",
    "\n",
    "    def copy_flat_params(self, cI):\n",
    "        \"\"\"Copy consecutive slices of the flat vector `cI` into the parameters, in parameter order.\"\"\"\n",
    "        start = 0\n",
    "        for p in self.model.parameters():\n",
    "            stop = start + p.numel()\n",
    "            p.data.copy_(cI[start:stop].view_as(p))\n",
    "            start = stop\n",
    "\n",
    "    def transfer_params(self, learner_w_grad, cI):\n",
    "        \"\"\"Take batchnorm statistics from `learner_w_grad` and weights/biases from slices of `cI`.\"\"\"\n",
    "        # load_state_dict is used only to copy the running mean/var in batchnorm; the parameter\n",
    "        #  values it copies are overwritten from cI just below\n",
    "        self.load_state_dict(learner_w_grad.state_dict())\n",
    "\n",
    "        # the nn.Parameters are replaced by plain tensors cloned from cI (NOT nn.Parameters anymore)\n",
    "        offset = 0\n",
    "        for m in self.model.modules():\n",
    "            if not isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):\n",
    "                continue\n",
    "            old_weight = m._parameters['weight']\n",
    "            n_weight = old_weight.numel()\n",
    "            m._parameters['weight'] = cI[offset: offset + n_weight].view_as(old_weight).clone()\n",
    "            offset += n_weight\n",
    "\n",
    "            old_bias = m._parameters['bias']\n",
    "            if old_bias is not None:\n",
    "                n_bias = old_bias.numel()\n",
    "                m._parameters['bias'] = cI[offset: offset + n_bias].view_as(old_bias).clone()\n",
    "                offset += n_bias\n",
    "\n",
    "    def reset_batch_stats(self):\n",
    "        \"\"\"Reset the running statistics of every BatchNorm2d layer.\"\"\"\n",
    "        batchnorms = (m for m in self.modules() if isinstance(m, nn.BatchNorm2d))\n",
    "        for bn in batchnorms:\n",
    "            bn.reset_running_stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 3: Meta Learner "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MetaLSTMCell(nn.Module):\n",
    "    \"\"\"Meta-LSTM cell whose cell state is the learner's flattened parameter vector.\n",
    "\n",
    "    Update rule: c_next = sigmoid(f_next) * c_prev - sigmoid(i_next) * grad\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, n_learner_params):\n",
    "        \"\"\"Args:\n",
    "            input_size (int): cell input size, default = 20\n",
    "            hidden_size (int): should be 1\n",
    "            n_learner_params (int): number of learner's parameters\n",
    "        \"\"\"\n",
    "        super(MetaLSTMCell, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.n_learner_params = n_learner_params\n",
    "\n",
    "        # each gate sees [x_all, c_prev, previous gate]; the two extra columns assume hidden_size == 1\n",
    "        gate_width = input_size + 2\n",
    "        self.WF = nn.Parameter(torch.empty(gate_width, hidden_size))\n",
    "        self.WI = nn.Parameter(torch.empty(gate_width, hidden_size))\n",
    "        self.cI = nn.Parameter(torch.empty(n_learner_params, 1))\n",
    "        self.bI = nn.Parameter(torch.empty(1, hidden_size))\n",
    "        self.bF = nn.Parameter(torch.empty(1, hidden_size))\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Small uniform init for every parameter, then re-draw the two gate biases.\"\"\"\n",
    "        for tensor in self.parameters():\n",
    "            nn.init.uniform_(tensor, -0.01, 0.01)\n",
    "\n",
    "        # want initial forget value to be high and input value to be low so that\n",
    "        #  model starts with gradient descent\n",
    "        nn.init.uniform_(self.bF, 4, 6)\n",
    "        nn.init.uniform_(self.bI, -5, -4)\n",
    "\n",
    "    def init_cI(self, flat_params):\n",
    "        \"\"\"Overwrite the learned initial cell state with `flat_params` (1-D) as a column.\"\"\"\n",
    "        as_column = flat_params.unsqueeze(1)\n",
    "        self.cI.data.copy_(as_column)\n",
    "\n",
    "    def forward(self, inputs, hx=None):\n",
    "        \"\"\"Args:\n",
    "            inputs = [x_all, grad]:\n",
    "                x_all (torch.Tensor of size [n_learner_params, input_size]): outputs from previous LSTM\n",
    "                grad (torch.Tensor of size [n_learner_params]): gradients from learner\n",
    "            hx = [f_prev, i_prev, c_prev]:\n",
    "                f (torch.Tensor of size [n_learner_params, 1]): forget gate\n",
    "                i (torch.Tensor of size [n_learner_params, 1]): input gate\n",
    "                c (torch.Tensor of size [n_learner_params, 1]): flattened learner parameters\n",
    "\n",
    "        Returns:\n",
    "            (c_next, [f_next, i_next, c_next]); f_next and i_next are the pre-sigmoid gate values\n",
    "        \"\"\"\n",
    "        x_all, grad = inputs\n",
    "        n_rows, _ = x_all.size()\n",
    "\n",
    "        if hx is None:\n",
    "            # first step: zero gates, and the learned cI as the starting parameters\n",
    "            zero_f = torch.zeros((n_rows, self.hidden_size)).to(self.WF.device)\n",
    "            zero_i = torch.zeros((n_rows, self.hidden_size)).to(self.WI.device)\n",
    "            hx = [zero_f, zero_i, self.cI]\n",
    "\n",
    "        f_prev, i_prev, c_prev = hx\n",
    "\n",
    "        # f_t = sigmoid(W_f * [grad_t, loss_t, theta_{t-1}, f_{t-1}] + b_f)\n",
    "        forget_in = torch.cat((x_all, c_prev, f_prev), 1)\n",
    "        f_next = torch.mm(forget_in, self.WF) + self.bF.expand_as(f_prev)\n",
    "        # i_t = sigmoid(W_i * [grad_t, loss_t, theta_{t-1}, i_{t-1}] + b_i)\n",
    "        input_in = torch.cat((x_all, c_prev, i_prev), 1)\n",
    "        i_next = torch.mm(input_in, self.WI) + self.bI.expand_as(i_prev)\n",
    "        # next cell/params\n",
    "        keep = torch.sigmoid(f_next).mul(c_prev)\n",
    "        step = torch.sigmoid(i_next).mul(grad)\n",
    "        c_next = keep - step\n",
    "\n",
    "        return c_next, [f_next, i_next, c_next]\n",
    "\n",
    "    def extra_repr(self):\n",
    "        return '{}, {}, {}'.format(self.input_size, self.hidden_size, self.n_learner_params)\n",
    "\n",
    "\n",
    "class MetaLearner(nn.Module):\n",
    "    \"\"\"An nn.LSTMCell over the preprocessed loss/gradient features, followed by a MetaLSTMCell\n",
    "    that produces the learner's next flattened parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, n_learner_params):\n",
    "        \"\"\"Args:\n",
    "            input_size (int): for the first LSTM layer, default = 4\n",
    "            hidden_size (int): for the first LSTM layer, default = 20\n",
    "            n_learner_params (int): number of learner's parameters\n",
    "        \"\"\"\n",
    "        super(MetaLearner, self).__init__()\n",
    "        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)\n",
    "        self.metalstm = MetaLSTMCell(input_size=hidden_size, hidden_size=1, n_learner_params=n_learner_params)\n",
    "\n",
    "    def forward(self, inputs, hs=None):\n",
    "        \"\"\"Args:\n",
    "            inputs = [loss, grad_prep, grad]\n",
    "                loss (torch.Tensor of size [1, 2])\n",
    "                grad_prep (torch.Tensor of size [n_learner_params, 2])\n",
    "                grad (torch.Tensor of size [n_learner_params, 1])\n",
    "            hs = [(lstm_hn, lstm_cn), [metalstm_fn, metalstm_in, metalstm_cn]]\n",
    "\n",
    "        Returns:\n",
    "            (flat learner parameters of size [n_learner_params], new hs in the same layout as the input hs)\n",
    "        \"\"\"\n",
    "        loss, grad_prep, grad = inputs\n",
    "        # the single loss row is repeated for every learner parameter\n",
    "        loss = loss.expand_as(grad_prep)\n",
    "        inputs = torch.cat((loss, grad_prep), 1)   # [n_learner_params, 4]\n",
    "\n",
    "        if hs is None:\n",
    "            hs = [None, None]\n",
    "\n",
    "        lstmhx, lstmcx = self.lstm(inputs, hs[0])\n",
    "        flat_learner_unsqzd, metalstm_hs = self.metalstm([lstmhx, grad], hs[1])\n",
    "\n",
    "        # Drop only the trailing size-1 column: a bare squeeze() would also remove the first\n",
    "        #  dimension when n_learner_params == 1 and hand callers a 0-d tensor they cannot slice.\n",
    "        return flat_learner_unsqzd.squeeze(1), [(lstmhx, lstmcx), metalstm_hs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 4: utils "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def accuracy(output, target, topk=(1,)):\n",
    "    \"\"\"Top-k accuracy, in percent, of the scores `output` against the labels `target`.\n",
    "\n",
    "    Args:\n",
    "        output (torch.Tensor of size [batch, n_classes]): per-class scores\n",
    "        target (torch.Tensor of size [batch]): class indices\n",
    "        topk (tuple of int): the k values to evaluate\n",
    "\n",
    "    Returns:\n",
    "        a float when a single k is requested, otherwise a list of floats (one per k)\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        maxk = max(topk)\n",
    "        batch_size = target.size(0)\n",
    "\n",
    "        # pred: [maxk, batch], row r holds each sample's (r+1)-th best class\n",
    "        _, pred = output.topk(maxk, 1, True, True)\n",
    "        pred = pred.t()\n",
    "        correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
    "\n",
    "        res = []\n",
    "        for k in topk:\n",
    "            # reshape, not view: for k > 1 the slice of the transposed tensor is non-contiguous\n",
    "            #  and view(-1) raises a RuntimeError\n",
    "            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
    "            res.append(correct_k.mul_(100.0 / batch_size))\n",
    "        return res[0].item() if len(res) == 1 else [r.item() for r in res]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def preprocess_grad_loss(x):\n",
    "    \"\"\"Map each entry of the 1-D tensor `x` to two features, stacked along dim 1.\n",
    "\n",
    "    With p = 10, for |x| >= exp(-p) the pair is (log(|x| + 1e-8) / p, sign(x));\n",
    "    otherwise it is (-1, exp(p) * x).\n",
    "    \"\"\"\n",
    "    p = 10\n",
    "    magnitude = x.abs()\n",
    "    # 1.0 where the entry is large enough for the log branch, 0.0 elsewhere\n",
    "    indicator = (magnitude >= np.exp(-p)).to(torch.float32)\n",
    "    small = 1 - indicator\n",
    "\n",
    "    # first feature: scaled log-magnitude, or -1 for tiny entries\n",
    "    log_feature = indicator * torch.log(magnitude + 1e-8) / p + small * -1\n",
    "    # second feature: sign, or the rescaled raw value for tiny entries\n",
    "    sign_feature = indicator * torch.sign(x) + small * np.exp(p) * x\n",
    "    return torch.stack((log_feature, sign_feature), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 5: Main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def meta_test(eps, eval_loader, learner_w_grad, learner_wo_grad, metalearner, args):\n",
    "    \"\"\"Evaluate the meta-learner on every episode of `eval_loader`.\n",
    "\n",
    "    For each episode the learner is trained on the support images by `train_learner`, the\n",
    "    resulting parameters are transferred to `learner_wo_grad`, and loss/accuracy are measured\n",
    "    on the query images.\n",
    "\n",
    "    Args:\n",
    "        eps (int): index of the current meta-training episode (not used here)\n",
    "        eval_loader: yields (episode_x, episode_y); episode_x is [n_class, n_shot + n_eval, c, h, w]\n",
    "        learner_w_grad, learner_wo_grad (Learner): the two learner copies\n",
    "        metalearner (MetaLearner): the meta-learner being evaluated\n",
    "        args (dict): reads 'n_shot', 'n_eval', 'n_class' (plus whatever train_learner reads)\n",
    "\n",
    "    Returns:\n",
    "        float: mean query accuracy (percent) over all episodes, 0.0 if the loader is empty\n",
    "    \"\"\"\n",
    "    accs = []\n",
    "    for subeps, (episode_x, episode_y) in enumerate(tqdm(eval_loader, ascii=True)):\n",
    "        train_input = episode_x[:, :args['n_shot']].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_shot, :]\n",
    "        train_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_shot'])) # [n_class * n_shot]\n",
    "        test_input = episode_x[:, args['n_shot']:].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_eval, :]\n",
    "        test_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_eval'])) # [n_class * n_eval]\n",
    "\n",
    "        # Train learner with metalearner\n",
    "        learner_w_grad.reset_batch_stats()\n",
    "        learner_wo_grad.reset_batch_stats()\n",
    "        learner_w_grad.train()\n",
    "        learner_wo_grad.eval()\n",
    "        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)\n",
    "\n",
    "        learner_wo_grad.transfer_params(learner_w_grad, cI)\n",
    "        # nothing is back-propagated during evaluation, so skip building the graph\n",
    "        with torch.no_grad():\n",
    "            output = learner_wo_grad(test_input)\n",
    "            loss = learner_wo_grad.criterion(output, test_target)\n",
    "        acc = accuracy(output, test_target)\n",
    "        accs.append(acc)\n",
    "        print (\"Validation/Test Loss: {}, and Accuracy {}\".format(loss.item(), acc))\n",
    "\n",
    "    if not accs:\n",
    "        return 0.0\n",
    "\n",
    "    # report the average over all episodes rather than whichever episode happened to come last\n",
    "    mean_acc = sum(accs) / len(accs)\n",
    "    print (\"Mean Validation/Test Accuracy over {} episodes: {}\".format(len(accs), mean_acc))\n",
    "    return mean_acc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_learner(learner_w_grad, metalearner, train_input, train_target, args):\n",
    "    \"\"\"Run the learner over the support set and let the meta-learner propose its parameters.\n",
    "\n",
    "    For args['epoch'] passes over `train_input` in chunks of args['batch_size'], the current\n",
    "    flat parameters are copied into `learner_w_grad`, its loss and gradients are computed, and\n",
    "    both are fed (preprocessed) to `metalearner`, which returns the next flat parameters.\n",
    "\n",
    "    Returns:\n",
    "        the flat parameters produced by the last meta-learner step (the meta-learner's initial\n",
    "        cI data if no step was taken)\n",
    "    \"\"\"\n",
    "    batch_size = args['batch_size']\n",
    "    cI = metalearner.metalstm.cI.data\n",
    "    state = None   # meta-learner hidden state, carried from one step to the next\n",
    "    for _ in range(args['epoch']):\n",
    "        for start in range(0, len(train_input), batch_size):\n",
    "            stop = start + batch_size\n",
    "            x, y = train_input[start:stop], train_target[start:stop]\n",
    "\n",
    "            # get the loss/grad\n",
    "            learner_w_grad.copy_flat_params(cI)\n",
    "            output = learner_w_grad(x)\n",
    "            loss = learner_w_grad.criterion(output, y)\n",
    "            acc = accuracy(output, y)\n",
    "            learner_w_grad.zero_grad()\n",
    "            loss.backward()\n",
    "            flat_grads = [p.grad.data.view(-1) / batch_size for p in learner_w_grad.parameters()]\n",
    "            grad = torch.cat(flat_grads, 0)\n",
    "\n",
    "            # preprocess grad & loss and metalearner forward\n",
    "            grad_prep = preprocess_grad_loss(grad)  # [n_learner_params, 2]\n",
    "            loss_prep = preprocess_grad_loss(loss.data.unsqueeze(0)) # [1, 2]\n",
    "            cI, state = metalearner([loss_prep, grad_prep, grad.unsqueeze(1)], state)\n",
    "            print(\"training loss: {:8.6f} acc: {:6.3f}, mean grad: {:8.6f}\".format(\n",
    "                loss.item(), acc, torch.mean(grad).item()))\n",
    "\n",
    "    return cI\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    \"\"\"Meta-train (args['mode'] == 'train') or meta-test (args['mode'] == 'test') the LSTM meta-learner.\n",
    "\n",
    "    Args:\n",
    "        args (dict): hyper-parameters; this function reads 'mode', 'image_size', 'bn_eps',\n",
    "            'bn_momentum', 'n_class', 'n_shot', 'n_eval', 'input_size', 'hidden_size', 'lr',\n",
    "            'episode' and 'grad_clip'; the data, training and evaluation helpers read further keys.\n",
    "    \"\"\"\n",
    "    seed = 2019\n",
    "    # also seed Python's own RNG: depending on the installed torchvision version the random\n",
    "    #  transforms may draw from it, and seeding it is harmless otherwise\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # Get data\n",
    "    train_loader, val_loader, test_loader = prepare_data(args)\n",
    "\n",
    "    # Set up learner, meta-learner\n",
    "    learner_w_grad = Learner(args['image_size'], args['bn_eps'], args['bn_momentum'], args['n_class'])\n",
    "    learner_wo_grad = copy.deepcopy(learner_w_grad)\n",
    "    metalearner = MetaLearner(args['input_size'], args['hidden_size'], learner_w_grad.get_flat_params().size(0))\n",
    "    metalearner.metalstm.init_cI(learner_w_grad.get_flat_params())\n",
    "\n",
    "    # Set up loss, optimizer, learning rate scheduler\n",
    "    optim = torch.optim.Adam(metalearner.parameters(), args['lr'])\n",
    "\n",
    "    if args['mode'] == 'test':\n",
    "        # NOTE(review): the meta-learner built above is freshly initialised and nothing is loaded\n",
    "        #  into it, so this evaluates an untrained meta-learner -- args['resume'] is never read.\n",
    "        _ = meta_test(args['episode'], test_loader, learner_w_grad, learner_wo_grad, metalearner, args)\n",
    "        return\n",
    "    best_acc = 0.0\n",
    "    # Meta-training\n",
    "    for eps, (episode_x, episode_y) in enumerate(train_loader):\n",
    "        # episode_x.shape = [n_class, n_shot + n_eval, c, h, w]\n",
    "        # episode_y.shape = [n_class, n_shot + n_eval] --> NEVER USED\n",
    "        # Targets are the episode-local class indices 0..n_class-1, one per class, repeated for\n",
    "        #  each of that class's images (they must be built from n_class, not n_shot).\n",
    "        train_input = episode_x[:, :args['n_shot']].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_shot, :]\n",
    "        train_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_shot'])) # [n_class * n_shot]\n",
    "        test_input = episode_x[:, args['n_shot']:].reshape(-1, *episode_x.shape[-3:]) # [n_class * n_eval, :]\n",
    "        test_target = torch.LongTensor(np.repeat(range(args['n_class']), args['n_eval'])) # [n_class * n_eval]\n",
    "\n",
    "        # Train learner with metalearner\n",
    "        learner_w_grad.reset_batch_stats()\n",
    "        learner_wo_grad.reset_batch_stats()\n",
    "        learner_w_grad.train()\n",
    "        learner_wo_grad.train()\n",
    "        cI = train_learner(learner_w_grad, metalearner, train_input, train_target, args)\n",
    "\n",
    "        # Train meta-learner with validation loss\n",
    "        learner_wo_grad.transfer_params(learner_w_grad, cI)\n",
    "        output = learner_wo_grad(test_input)\n",
    "        loss = learner_wo_grad.criterion(output, test_target)\n",
    "\n",
    "        optim.zero_grad()\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(metalearner.parameters(), args['grad_clip'])\n",
    "        optim.step()\n",
    "\n",
    "        print (eps)\n",
    "        # Meta-validation\n",
    "        if eps % 10 == 0 and eps != 0:\n",
    "            acc = meta_test(eps, val_loader, learner_w_grad, learner_wo_grad, metalearner, args)\n",
    "            if acc > best_acc:\n",
    "                best_acc = acc\n",
    "                print (\"* Best accuracy so far *\\n\")\n",
    "\n",
    "    print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " BEGIN TRAINING: \n",
      "training loss: 1.668335 acc: 25.000, mean grad: 0.000055\n",
      "training loss: 1.906144 acc: 11.111, mean grad: -0.000132\n",
      "training loss: 1.607665 acc: 31.250, mean grad: 0.000069\n",
      "training loss: 1.632667 acc: 11.111, mean grad: -0.000136\n",
      "training loss: 1.566477 acc: 31.250, mean grad: 0.000058\n",
      "training loss: 1.440862 acc: 22.222, mean grad: -0.000116\n",
      "training loss: 1.531511 acc: 25.000, mean grad: 0.000073\n",
      "training loss: 1.299537 acc: 44.444, mean grad: -0.000100\n",
      "training loss: 1.498922 acc: 25.000, mean grad: 0.000054\n",
      "training loss: 1.192788 acc: 66.667, mean grad: -0.000099\n",
      "training loss: 1.468537 acc: 18.750, mean grad: 0.000059\n",
      "training loss: 1.111149 acc: 66.667, mean grad: -0.000090\n",
      "training loss: 1.440375 acc: 25.000, mean grad: 0.000043\n",
      "training loss: 1.046557 acc: 66.667, mean grad: -0.000107\n",
      "training loss: 1.413390 acc: 31.250, mean grad: 0.000042\n",
      "training loss: 0.993363 acc: 77.778, mean grad: -0.000104\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    # Hyper-parameters for the meta-training run.\n",
    "    args_train = {\n",
    "        'mode': 'train',\n",
    "        'n_shot': 5, 'n_eval': 15, 'n_class': 5,\n",
    "        'input_size': 4, 'hidden_size': 20,\n",
    "        'lr': 1e-3,\n",
    "        'episode': 50, 'episode_val': 50,\n",
    "        'epoch': 8, 'batch_size': 16,\n",
    "        'image_size': 84,\n",
    "        'grad_clip': 0.25,\n",
    "        'bn_momentum': 0.95, 'bn_eps': 1e-3,\n",
    "        'data': \"miniimagenet\", 'data_root': \"./miniImagenet/\",\n",
    "        'resume': None,\n",
    "    }\n",
    "\n",
    "    # The test run shares every setting except the mode and the number of evaluation episodes.\n",
    "    args_test = dict(args_train, mode='test', episode_val=10)\n",
    "\n",
    "    print (\" BEGIN TRAINING: \")\n",
    "    main(args_train)\n",
    "\n",
    "    print (\"BEGIN TESTING\")\n",
    "    main(args_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
